{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing RoBERTa\n",
    "\n",
    "This is a quick test of how RoBERTa handles sentence pairs it has not seen before: does it judge the second sentence to contradict the first, to be neutral with respect to it, or to follow from it (entailment)? The model was already fine-tuned on the MNLI dataset by the team at Facebook, so no further training is done here. I was curious to see how it did on a small sample of sentences it had not seen before.\n",
    "\n",
    "For each pair of sentences, the model picks one of three labels.\n",
    "\n",
    "1. Contradiction: the second sentence contradicts the first.\n",
    "2. Neutral: the second sentence neither contradicts nor follows from the first.\n",
    "3. Entailment: the second sentence follows logically from the first.\n",
    "\n",
    "Results are below.\n",
    "\n",
    "##### Aug 2019"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Load RoBERTa large, already fine-tuned on MNLI, from the fairseq torch.hub entry point.\n",
    "roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')\n",
    "roberta.eval()  # disable dropout so that predictions are deterministic\n",
    "\n",
    "\n",
    "def hello_roberta():\n",
    "    # Classify four hand-written sentence pairs and print the predicted label for each.\n",
    "    # Class indices returned by the 'mnli' classification head.\n",
    "    label = {0: 'Contradiction', 1: 'Neutral', 2: 'Entailment'}\n",
    "    # First sentence of each pair; s1[i] is paired with s2[i].\n",
    "    s1 = [\n",
    "        'Jamie worked for a bank as of August 2019.',\n",
    "        'There was a meeting in the boardroom.',\n",
    "        'The client asked me to do it.',\n",
    "        'The client wanted me to buy it on the open and call her stright away.',\n",
    "    ]\n",
    "    # Second sentence of each pair, judged against the first.\n",
    "    s2 = [\n",
    "        'Jamie was fired from the bank in July 2019.',\n",
    "        'Risk held their weekly update in the boardroom.',\n",
    "        'I didnt get permission from the client.',\n",
    "        'Once I bought it for the client I called the client immediately.',\n",
    "    ]\n",
    "    for first, second in zip(s1, s2):\n",
    "        # Encode the pair into a single tensor of token ids.\n",
    "        tokens = roberta.encode(first, second)\n",
    "        # Take the index of the highest-scoring class.\n",
    "        prediction = roberta.predict('mnli', tokens).argmax().item()\n",
    "        print(\"\")\n",
    "        print(first, second)\n",
    "        print(\"Roberta's prediction:\", [label[prediction]])"
   ]
  },
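  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`hello_roberta()` keeps only the top class via `argmax()`, which hides how confident the model was. Assuming `roberta.predict('mnli', tokens)` returns log-probabilities over the three classes, a minimal sketch of turning such scores into a label plus a probability could look like the following. The helper `describe_prediction` and the example numbers are hypothetical and exist only for illustration; they are not model output.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Same index-to-label mapping as in hello_roberta().\n",
    "LABELS = {0: 'Contradiction', 1: 'Neutral', 2: 'Entailment'}\n",
    "\n",
    "\n",
    "def describe_prediction(log_probs):\n",
    "    # Hypothetical helper: takes three log-probabilities and returns\n",
    "    # (label, probability) for the highest-scoring class.\n",
    "    # exp() turns each log-probability back into a probability.\n",
    "    probs = [math.exp(lp) for lp in log_probs]\n",
    "    # Index of the largest probability, the same choice argmax() makes.\n",
    "    best = max(range(len(probs)), key=probs.__getitem__)\n",
    "    return LABELS[best], probs[best]\n",
    "\n",
    "\n",
    "# Made-up log-probabilities for illustration only, not real model output.\n",
    "example = [math.log(0.7), math.log(0.2), math.log(0.1)]\n",
    "label, prob = describe_prediction(example)\n",
    "print(f'{label} ({prob:.0%})')\n",
    "```\n",
    "\n",
    "With the real model, the input would be something like `roberta.predict('mnli', tokens)[0].tolist()`, assuming the output has one row of three scores; that call is not run here."
   ]
  },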
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Jamie worked for a bank as of August 2019. Jamie was fired from the bank in July 2019.\n",
      "Roberta's prediction: ['Contradiction']\n",
      "\n",
      "There was a meeting in the boardroom. Risk held their weekly update in the boardroom.\n",
      "Roberta's prediction: ['Neutral']\n",
      "\n",
      "The client asked me to do it. I didnt get permission from the client.\n",
      "Roberta's prediction: ['Contradiction']\n",
      "\n",
      "The client wanted me to buy it on the open and call her stright away. Once I bought it for the client I called the client immediately.\n",
      "Roberta's prediction: ['Entailment']\n"
     ]
    }
   ],
   "source": [
    "# Run the test on the four sentence pairs.\n",
    "hello_roberta()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
